## Feed Forward Style Transfer

In [None]:
import os, random, time, platform
from pathlib import Path
from PIL import Image, ImageOps
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, models, utils
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import gdown

In [None]:
# Select the computation device, in order of preference: CUDA, Apple MPS, CPU.
# `torch.has_mps` is deprecated, so MPS support is detected through the
# `torch.backends.mps` namespace instead; the getattr guard keeps the cell
# working on PyTorch builds that do not ship that namespace at all.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    USE_AMP = True   # mixed precision is only enabled on CUDA
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
    USE_AMP = False
else:
    DEVICE = torch.device("cpu")
    USE_AMP = False

print("Device:", DEVICE, "USE_AMP:", USE_AMP)

Device: mps USE_AMP: False


  elif getattr(torch, "has_mps", False) and torch.backends.mps.is_available():


In [None]:
# Hyperparameters
IMG_SIZE = 512        # size passed to Resize/CenterCrop for content and style images
BATCH_SIZE = 4        # content images sampled per optimisation step
NUM_EPOCHS = 10       # iterations of the outer training loop
LR = 1e-3             # learning rate handed to Adam
CONTENT_WEIGHT = 0.5  # multiplier on the content (feature MSE) term of the loss
STYLE_WEIGHT   = 5e6  # multiplier on the style (Gram-matrix MSE) term of the loss
TV_WEIGHT      = 0    # multiplier on the total-variation (smoothness) term; 0 removes its contribution

In [None]:
# File paths: raw inputs, the generated split tree, and the output folders.
CONTENT_ROOT = "../Data/dataset/clean/animals_balanced"
STYLE_ROOT = "../Data/dataset/clean/origami_images"
SPLIT_ROOT = "../Data/dataset/split"
CHECKPOINT_DIR = "./checkpoints_nststyle"
SAMPLES_DIR = "./samples_nststyle"

# Single class for now (initial experiment).
TARGET_CLASS = "butterfly"

# The output folders must exist before anything is written into them.
for d in (SPLIT_ROOT, CHECKPOINT_DIR, SAMPLES_DIR):
    os.makedirs(d, exist_ok=True)

# Pre-create SPLIT_ROOT/{content,style}/{train,val,test}/TARGET_CLASS.
for split in ("train", "val", "test"):
    for root in ("content", "style"):
        path = os.path.join(SPLIT_ROOT, root, split, TARGET_CLASS)
        os.makedirs(path, exist_ok=True)
        

In [None]:
# Per-channel mean and standard deviation that normalize_for_vgg subtracts and
# divides by before images are fed to VGG.
# NOTE(review): these look like the usual ImageNet RGB statistics — confirm.
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD  = [0.229, 0.224, 0.225]

def list_images(dir_):
    """Return the sorted paths of the image entries directly inside ``dir_``.

    An entry counts as an image when its name ends in .jpg, .jpeg or .png
    (case-insensitive). Each returned path is ``dir_`` joined with the entry
    name; nested folders are not walked.
    """
    image_exts = ('.jpg', '.jpeg', '.png')
    found = []
    for entry in os.listdir(dir_):
        if entry.lower().endswith(image_exts):
            found.append(os.path.join(dir_, entry))
    found.sort()
    return found
    
def exif_fix_and_open(path):
    """Load an image, apply its EXIF orientation, and return it in RGB mode.

    Args:
        path: location of the image file, passed to ``Image.open``.

    Returns:
        The result of ``ImageOps.exif_transpose`` converted with ``convert("RGB")``.

    The file is opened in a ``with`` block so the underlying file handle is
    released deterministically instead of whenever the image object happens
    to be garbage-collected. This function is called once per sampled image
    on every training step, so lingering handles would add up.
    """
    with Image.open(path) as raw:
        upright = ImageOps.exif_transpose(raw)
        # convert() produces a new image, so the result stays usable after
        # the source file has been closed.
        return upright.convert("RGB")

def simple_bar(step, total, epoch, loss=None, bar_len=20):
    """Redraw a one-line text progress bar for the given epoch on stdout.

    Args:
        step: current step; ``step / total`` is the fraction shown.
        total: number of steps in the epoch.
        epoch: epoch number shown in the prefix.
        loss: optional value appended as `` | loss: x.xxxx``.
        bar_len: width of the bar in characters.
    """
    fraction = step / total
    done = int(bar_len * fraction)
    gauge = "=" * done + "." * (bar_len - done)
    loss_part = "" if loss is None else f" | loss: {loss:.4f}"
    # The leading carriage return rewinds the cursor so the bar redraws in place.
    print(f"\rEpoch {epoch}: {gauge} {fraction*100:5.1f}%" + loss_part, end="", flush=True)
    if step == total:
        print()  # finish the line once the last step has been drawn


In [None]:
#same as nst.py
def normalize_for_vgg(x):
    """Standardise an image batch with IMG_MEAN / IMG_STD before it enters VGG.

    Args:
        x: tensor that broadcasts against shape (1, 3, 1, 1), i.e. with three
            channels on dim 1.

    Returns:
        ``(x - mean) / std`` with the per-channel constants broadcast over the
        batch and spatial dimensions.

    The constants are created directly on ``x``'s device rather than on the CPU
    followed by a copy to the global DEVICE, so the function also works for
    inputs that live on a different device than DEVICE.
    """
    mean = torch.tensor(IMG_MEAN, device=x.device).view(1, 3, 1, 1)
    std = torch.tensor(IMG_STD, device=x.device).view(1, 3, 1, 1)
    return (x - mean) / std

# def extract_features_batch(x, layers, model):
#     x_vgg = normalize_for_vgg(x)
#     cur = x_vgg
#     features = {}
#     layers_to_extract = {LAYER_INDICES[name]: name for name in layers}
#     for idx, layer in model._modules.items():
#         cur = layer(cur)
#         if idx in layers_to_extract:
#             features[layers_to_extract[idx]] = cur
#     return features

def gram_matrix_batch(tensor):
    """Compute a normalised Gram matrix for every sample in a batch.

    Args:
        tensor: 4-D tensor unpacked as (b, c, h, w).

    Returns:
        Tensor of shape (b, c, c): for each sample, the inner products between
        all pairs of channels taken over the h*w positions, divided by c*h*w.
    """
    b, c, h, w = tensor.size()
    # reshape (rather than view) so non-contiguous inputs, e.g. permuted or
    # channels-last feature maps, are accepted instead of raising.
    f = tensor.reshape(b, c, h * w)
    return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)


#### Data Splitting

In [None]:
def ensure_splits(class_name, val_frac=0.1, test_frac=0.05, seed=42):
    """Materialise train/val/test splits of the content and style images of one class.

    For each domain ("content" from CONTENT_ROOT, "style" from STYLE_ROOT) the
    image files in ``root/class_name`` are shuffled with ``seed`` and divided
    into val (first ``val_frac``), test (next ``test_frac``) and train (the
    rest). Each image is opened with ``exif_fix_and_open`` and written as JPEG
    data to ``SPLIT_ROOT/domain/split/class_name`` under its original file name.

    Args:
        class_name: sub-folder name under both source roots.
        val_frac: fraction of the images assigned to the validation split.
        test_frac: fraction of the images assigned to the test split.
        seed: value passed to ``random.seed`` (this re-seeds the global RNG).

    Raises:
        AssertionError: if a source folder does not exist.

    Two safeguards keep the splits disjoint and reproducible:
      * the directory listing is sorted before shuffling, because
        ``os.listdir`` returns entries in arbitrary order and the seeded
        shuffle would otherwise differ between machines or runs;
      * an image that already exists in *any* split folder of its domain is
        left where it is, so a re-run (new seed, new files, different listing
        order) can never copy the same image into a second split.
    """
    random.seed(seed)
    for domain, root in [("content", CONTENT_ROOT), ("style", STYLE_ROOT)]:
        src = os.path.join(root, class_name)
        assert os.path.isdir(src), f"Missing {domain} folder: {src}"
        files = sorted(f for f in os.listdir(src)
                       if f.lower().endswith(('.jpg', '.jpeg', '.png')))
        random.shuffle(files)
        n = len(files)
        n_val = int(n * val_frac)
        n_test = int(n * test_frac)
        splits = {
            "train": files[n_val + n_test:],
            "val":   files[:n_val],
            "test":  files[n_val:n_val + n_test],
        }
        split_dirs = {split: os.path.join(SPLIT_ROOT, domain, split, class_name)
                      for split in splits}
        for out in split_dirs.values():
            os.makedirs(out, exist_ok=True)
        for split, names in splits.items():
            for f in names:
                # Already materialised somewhere: keep the earlier assignment.
                if any(os.path.exists(os.path.join(d, f)) for d in split_dirs.values()):
                    continue
                srcf = os.path.join(src, f)
                dstf = os.path.join(split_dirs[split], f)
                # NOTE: the format is forced to JPEG while the original file
                # name is kept, so a ".png" name ends up holding JPEG data.
                exif_fix_and_open(srcf).save(dstf, "JPEG", quality=90)
    print(f"splits ready under {SPLIT_ROOT}/{class_name}")

# Create the split folders for the single class used in this notebook.
ensure_splits(TARGET_CLASS)

splits ready under ../Data/dataset/split/butterfly


#### VGG Layer Configs and Normalization

In [None]:
# # Same as NST
# LAYER_INDICES = {
#     'conv1_1': '0', 
#     'conv1_2': '2', 
#     'conv2_1': '5', 
#     'conv2_2': '7',
#     'conv3_1': '10', 
#     'conv3_2': '12', 
#     'conv3_3': '14', 
#     'conv3_4': '16',
#     'conv4_1': '19', 
#     'conv4_2': '21', 
#     'conv4_3': '23', 
#     'conv4_4': '25',
#     'conv5_1': '28', 
#     'conv5_2': '30', 
#     'conv5_3': '32', 
#     'conv5_4': '34'
# }

# LAYER_CONFIGS = {
#     'gatys': {
#         'content': ['conv4_2'],
#         'style': ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1'],
#         'style_weights': {
#             'conv1_1': 1.0,
#             'conv2_1': 0.8,
#             'conv3_1': 0.5,
#             'conv4_1': 0.3,
#             'conv5_1': 0.1
#         },
#     }
# }
# ACTIVE_LAYER_CONFIG = 'gatys'


#### Normalization for VGG and Feature Extraction

#### Data Sampling

In [None]:
# Preprocessing applied to every image before it enters the network:
# resize, centre-crop to IMG_SIZE, then convert to a tensor.
train_tf = transforms.Compose([
    transforms.Resize(size=IMG_SIZE),
    transforms.CenterCrop(size=IMG_SIZE),
    transforms.ToTensor(),
])


In [None]:
# Training images are read from the "train" split of each domain.
content_train_dir, style_train_dir = (
    os.path.join(SPLIT_ROOT, domain, "train", TARGET_CLASS)
    for domain in ("content", "style")
)
CONTENT_FILES = list_images(content_train_dir)
STYLE_FILES = list_images(style_train_dir)
# Fail fast when either training folder holds no images.
assert CONTENT_FILES and STYLE_FILES, "need images to train"

def sample_content_batch(batch_size):
    """Draw a random batch of training content images as one tensor on DEVICE.

    ``batch_size`` paths are picked from CONTENT_FILES with ``random.choices``
    (i.e. with replacement, so a batch may contain the same image twice), each
    image is loaded with ``exif_fix_and_open`` and transformed by ``train_tf``,
    and the results are stacked along a new leading dimension.
    """
    batch = []
    for img_path in random.choices(CONTENT_FILES, k=batch_size):
        batch.append(train_tf(exif_fix_and_open(img_path)))
    return torch.stack(batch).to(DEVICE)

#### Load Pre-Trained VGG

In [None]:
# Maps a layer name to the integer position of the module inside `vgg` whose
# output VGGFeatureExtractor records under that name.
# NOTE(review): the names presumably follow the usual VGG-19 convN_M naming of
# the convolution layers — confirm against the torchvision layer order.
LAYER_INDICES = {
    'conv1_1': 0,  'conv1_2': 2,
    'conv2_1': 5,  'conv2_2': 7,
    'conv3_1': 10, 'conv3_2': 12, 'conv3_3': 14, 'conv3_4': 16,
    'conv4_1': 19, 'conv4_2': 21, 'conv4_3': 23, 'conv4_4': 25,
    'conv5_1': 28, 'conv5_2': 30, 'conv5_3': 32, 'conv5_4': 34
}

# Layer used for the content loss (the training loop reads only the first entry).
CONTENT_LAYERS = ['conv4_2']
# Layers whose Gram matrices are compared for the style loss.
STYLE_LAYERS   = ['conv1_1','conv2_1','conv3_1','conv4_1','conv5_1','conv4_2']
# Multiplier applied to each style layer's Gram-matrix MSE before summing.
STYLE_LAYER_WEIGHTS = {
    'conv1_1': 1.0,
    'conv2_1': 0.8,
    'conv3_1': 0.6,
    'conv4_1': 0.4,
    'conv5_1': 0.2,
    'conv4_2': 0.2,   # adds structural push
}
# Load ImageNet-pretrained VGG-19 and keep only its `features` stack, in eval
# mode on DEVICE. Recent torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument; the old flag is kept as a fallback for
# installations that do not provide the weights enum.
_vgg19_weights = getattr(models, "VGG19_Weights", None)
if _vgg19_weights is not None:
    vgg = models.vgg19(weights=_vgg19_weights.IMAGENET1K_V1).features.to(DEVICE).eval()
else:
    vgg = models.vgg19(pretrained=True).features.to(DEVICE).eval()
# VGG is only a fixed feature extractor for the losses: freeze its parameters.
for p in vgg.parameters():
    p.requires_grad = False

class VGGFeatureExtractor(nn.Module):
    """Run a sequential stack and collect the outputs of selected positions.

    Args:
        vgg: module whose child modules are keyed by integer-like strings and
            applied one after another.
        layer_indices: mapping from feature name to the integer position of
            the child module whose output should be recorded.

    ``forward`` returns a dict mapping each feature name to the tensor produced
    at its position.

    NOTE(review): the stored tensor is the very object handed to the next
    child module; if that module works in place (torchvision's VGG is believed
    to use in-place ReLU after its convolutions), the stored features end up
    holding the post-activation values — confirm before relying on
    pre-activation semantics.
    """

    def __init__(self, vgg, layer_indices):
        super().__init__()
        self.vgg = vgg
        # Reverse lookup: position -> feature name.
        self.idx_to_name = dict(zip(layer_indices.values(), layer_indices.keys()))
        self.watch = set(layer_indices.values())

    def forward(self, x):
        captured = {}
        activation = x
        for key, module in self.vgg._modules.items():
            position = int(key)
            activation = module(activation)
            if position in self.watch:
                captured[self.idx_to_name[position]] = activation
        return captured

# Feature extractor over every layer listed in LAYER_INDICES, in eval mode on DEVICE.
vgg_feat = VGGFeatureExtractor(vgg, LAYER_INDICES).to(DEVICE).eval()



#### Transformer Network

In [None]:
class ConvLayer(nn.Module):
    """Reflection-padded convolution followed by instance norm and ReLU.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        k: convolution kernel size; the input is reflection-padded by ``k // 2``.
        s: convolution stride.
    """

    def __init__(self, in_c, out_c, k, s):
        super().__init__()
        # All padding is done by reflection here, so the convolution uses none.
        self.pad = nn.ReflectionPad2d(padding=k // 2)
        self.conv = nn.Conv2d(in_channels=in_c, out_channels=out_c,
                              kernel_size=k, stride=s, padding=0)
        self.inorm = nn.InstanceNorm2d(num_features=out_c, affine=False)

    def forward(self, x):
        padded = self.pad(x)
        normed = self.inorm(self.conv(padded))
        return F.relu(normed)

In [None]:
# class ResidualBlock(nn.Module): #learn style modifications
#     def __init__(self, channels):
#         super().__init__()
#         self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1)
#         self.in1 = nn.InstanceNorm2d(channels, affine=True)
#         self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
#         self.in2 = nn.InstanceNorm2d(channels, affine=True)
#     def forward(self, x):
#         out = F.relu(self.in1(self.conv1(x)))
#         out = self.in2(self.conv2(out))
#         return out + x

# class StylizedResidualBlock(nn.Module):
#     def __init__(self, channels):
#         super().__init__()
#         self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1)
#         self.in1 = nn.InstanceNorm2d(channels, affine=True)
#         self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
#         self.in2 = nn.InstanceNorm2d(channels, affine=True)
        
#         # style gate to enhance stylized contrast/edges
#         self.style_gate = nn.Sequential(
#             nn.Conv2d(channels, channels, 1),
#             nn.Sigmoid()
#         )

#     def forward(self, x):
#         out = F.relu(self.in1(self.conv1(x)))
#         out = self.in2(self.conv2(out))
#         gate = self.style_gate(out)
#         # modulate residual with learned style gate
#         out = out * gate + x
        
#         return out

class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip connection.

    The first convolution is followed by an affine instance norm and ReLU, the
    second by a non-affine instance norm; the input is added to the result.

    Args:
        c: number of channels (unchanged by the block).
    """

    def __init__(self, c):
        super().__init__()
        stages = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, stride=1, padding=0),
            nn.InstanceNorm2d(c, affine=True),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, stride=1, padding=0),
            nn.InstanceNorm2d(c, affine=False),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


In [None]:
# class UpsampleConv(nn.Module): # upsampling the image (making it bigger)
#     def __init__(self, in_c, out_c, kernel, upsample=None):
#         super().__init__()
#         self.upsample = upsample
#         padding = kernel // 2
#         self.conv = nn.Conv2d(in_c, out_c, kernel, 1, padding)
#         self.inorm = nn.InstanceNorm2d(out_c, affine=True)
        
#     def forward(self, x):
#         if self.upsample:
#             x = F.interpolate(x, scale_factor=self.upsample, mode='nearest')
            
#         return F.relu(self.inorm(self.conv(x)))
class UpNearestConv(nn.Module):
    def __init__(self, in_c, out_c, scale=2, k=3):
        super().__init__()
        self.scale = scale
        self.conv = nn.Conv2d(in_c, out_c, k, 1, padding=k//2)
        self.inorm = nn.InstanceNorm2d(out_c, affine=False)
    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale, mode='nearest')
        x = self.conv(x)
        return F.relu(self.inorm(x))

In [None]:
# class TransformerNet(nn.Module):
    
#     def __init__(self):
#         super().__init__()
#         self.conv1 = ConvLayer(3, 32, 9, 1)
#         self.conv2 = ConvLayer(32, 64, 3, 2)
#         self.conv3 = ConvLayer(64, 128, 3, 2)
#         self.res1 = ResidualBlock(128)
#         self.res2 = ResidualBlock(128)
#         self.res3 = ResidualBlock(128)
#         self.res4 = ResidualBlock(128)
#         self.res5 = ResidualBlock(128)
#         self.up1 = UpsampleConv(128, 64, 3, upsample=2)
#         self.up2 = UpsampleConv(64, 32, 3, upsample=2)
#         self.conv_out = nn.Conv2d(32, 3, 9, 1, 4)
        
#     def forward(self, x):
#         y = self.conv1(x)
#         y = self.conv2(y)
#         y = self.conv3(y)
#         y = self.res1(y)
#         y = self.res2(y)
#         y = self.res3(y)
#         y = self.res4(y)
#         y = self.res5(y)
#         y = self.up1(y)
#         y = self.up2(y)
#         y = self.conv_out(y)
#         return torch.sigmoid(y)

# class TransformerNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.conv1 = ConvLayer(3, 32, 9, 1)
#         self.conv2 = ConvLayer(32, 64, 3, 2)
#         self.conv3 = ConvLayer(64, 128, 3, 2)

#         #stylized Residuals
#         self.res1 = StylizedResidualBlock(128)
#         self.res2 = StylizedResidualBlock(128)
#         self.res3 = StylizedResidualBlock(128)
#         self.res4 = StylizedResidualBlock(128)
#         self.res5 = StylizedResidualBlock(128)

#         self.up1 = UpsampleConv(128, 64, 3, upsample=2)
#         self.up2 = UpsampleConv(64, 32, 3, upsample=2)
#         self.conv_out = nn.Conv2d(32, 3, 9, 1, 4)

#     def forward(self, x):
#         y = self.conv1(x)
#         y = self.conv2(y)
#         y = self.conv3(y)
#         y = self.res1(y)
#         y = self.res2(y)
#         y = self.res3(y)
#         y = self.res4(y)
#         y = self.res5(y)
#         y = self.up1(y)
#         y = self.up2(y)
#         y = self.conv_out(y)
#         return torch.sigmoid(y)

class TransformerNet(nn.Module):
    """Feed-forward image transformation network.

    Three ConvLayers (the last two with stride 2), five 128-channel
    ResidualBlocks, two UpNearestConv stages and a reflection-padded 9x9
    output convolution; the result is squashed with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.c1 = ConvLayer(3, 32, 9, 1)
        self.c2 = ConvLayer(32, 64, 3, 2)
        self.c3 = ConvLayer(64, 128, 3, 2)
        # Residual trunk
        self.r1 = ResidualBlock(128)
        self.r2 = ResidualBlock(128)
        self.r3 = ResidualBlock(128)
        self.r4 = ResidualBlock(128)
        self.r5 = ResidualBlock(128)
        # Decoder
        self.u1 = UpNearestConv(128, 64)
        self.u2 = UpNearestConv(64, 32)
        # Output head: pad by 4 so the unpadded 9x9 convolution keeps the size.
        self.pad_out = nn.ReflectionPad2d(4)
        self.conv_out = nn.Conv2d(32, 3, 9, 1, 0)

    def forward(self, x):
        pipeline = (
            self.c1, self.c2, self.c3,
            self.r1, self.r2, self.r3, self.r4, self.r5,
            self.u1, self.u2,
            self.pad_out, self.conv_out,
        )
        y = x
        for stage in pipeline:
            y = stage(y)
        return torch.sigmoid(y)

In [None]:
# The trainable image-transformation network and the Adam optimiser over its parameters.
model = TransformerNet().to(DEVICE)
opt = optim.Adam(model.parameters(), lr=LR)

def tv_loss_fn(x):
    """Total-variation penalty of a 4-D tensor.

    Returns the mean absolute difference between neighbours along the last
    dimension plus the mean absolute difference between neighbours along the
    second-to-last dimension.
    """
    horizontal = x[:, :, :, 1:] - x[:, :, :, :-1]
    vertical = x[:, :, 1:, :] - x[:, :, :-1, :]
    return horizontal.abs().mean() + vertical.abs().mean()

In [None]:
def precompute_style_grams(style_dir, transform, chunk=8):
    """Compute the Gram matrices of every image in ``style_dir`` once, up front.

    Args:
        style_dir: folder scanned with ``list_images``.
        transform: callable applied to each opened image; its outputs are stacked.
        chunk: number of images pushed through VGG at a time.

    Returns:
        A list with one dict per image, mapping every feature name returned by
        ``vgg_feat`` to that image's Gram matrix (batch dimension of 1 kept),
        moved to the CPU.
    """
    image_paths = list_images(style_dir)
    gram_dicts = []
    with torch.no_grad():
        for start in range(0, len(image_paths), chunk):
            loaded = [transform(exif_fix_and_open(p))
                      for p in image_paths[start:start + chunk]]
            batch = torch.stack(loaded).to(DEVICE)
            layer_feats = vgg_feat(normalize_for_vgg(batch))
            for row in range(batch.size(0)):
                per_image = {}
                for layer_name in layer_feats.keys():
                    # Slice (not index) so the Gram keeps a leading batch dim of 1.
                    single = layer_feats[layer_name][row:row + 1]
                    per_image[layer_name] = gram_matrix_batch(single).cpu()
                gram_dicts.append(per_image)
    print(f"Precomputed {len(gram_dicts)} style gram dicts.")
    return gram_dicts

# Style targets are taken from the *training* split only. The unsplit
# STYLE_ROOT folder also contains the images that ensure_splits assigned to
# val/test, so using it would let held-out style images influence training.
style_dir = os.path.join(SPLIT_ROOT, "style", "train", TARGET_CLASS)
# Grams are computed at IMG_SIZE with the same resize / centre-crop / to-tensor
# steps that are applied to the content images.
style_grams = precompute_style_grams(
    style_dir,
    transforms.Compose([transforms.Resize(IMG_SIZE),
                        transforms.CenterCrop(IMG_SIZE),
                        transforms.ToTensor()]),
    chunk=8,
)


KeyboardInterrupt: 

In [None]:
# Loss scaling is only used for mixed-precision training on CUDA.
scaler = torch.cuda.amp.GradScaler() if (USE_AMP and DEVICE.type=="cuda") else None
# Run at least 1000 steps per epoch; more when the content set has more batches than that.
steps_per_epoch = max(1000, len(CONTENT_FILES)//BATCH_SIZE)
LOG_EVERY = 200  # steps between averaged-loss reports and sample image dumps


def _forward_features(content_batch):
    """Stylise a batch and extract VGG features of both the inputs and the outputs."""
    out = model(content_batch)
    c_feats = vgg_feat(normalize_for_vgg(content_batch))
    o_feats = vgg_feat(normalize_for_vgg(out))
    return out, c_feats, o_feats


def _total_loss(out, c_feats, o_feats):
    """Weighted content + style + total-variation loss, evaluated in float32.

    Features are cast with ``.float()`` (a no-op for float32 inputs) and this
    function is called outside the autocast region: a Gram matrix sums
    products over every spatial position, which can exceed the float16 range
    when computed in half precision.
    """
    content_layer = CONTENT_LAYERS[0]
    c_loss = torch.mean((o_feats[content_layer].float() - c_feats[content_layer].float())**2)
    Gs = random.choice(style_grams)  # Grams of one randomly chosen style image per step
    s_loss = 0.0
    for l in STYLE_LAYERS:
        Go = gram_matrix_batch(o_feats[l].float())
        s_loss += STYLE_LAYER_WEIGHTS[l]*torch.mean((Go - Gs[l].to(DEVICE))**2)
    tv = tv_loss_fn(out.float())
    return CONTENT_WEIGHT*c_loss + STYLE_WEIGHT*s_loss + TV_WEIGHT*tv


print("Training (no dataloader)…")
for epoch in range(1, NUM_EPOCHS+1):
    model.train()
    running = 0.0
    for step in range(1, steps_per_epoch+1):
        content_batch = sample_content_batch(BATCH_SIZE)
        opt.zero_grad(set_to_none=True)

        if scaler is not None:
            # Only the network forward passes run under autocast; the loss is
            # evaluated afterwards in float32 (see _total_loss).
            with torch.cuda.amp.autocast():
                out, c_feats, o_feats = _forward_features(content_batch)
            total = _total_loss(out, c_feats, o_feats)
            scaler.scale(total).backward(); scaler.step(opt); scaler.update()
        else:
            out, c_feats, o_feats = _forward_features(content_batch)
            total = _total_loss(out, c_feats, o_feats)
            total.backward(); opt.step()

        # Read the scalar once per step instead of once per consumer.
        loss_value = total.item()
        running += loss_value

        simple_bar(step, steps_per_epoch, epoch, loss=loss_value)

        if step % LOG_EVERY == 0:
            avg = running/LOG_EVERY; running = 0.0
            if step != steps_per_epoch:
                # simple_bar leaves its line open until the final step; end it
                # so the report below does not get glued onto the bar.
                print()
            print(f"E{epoch} S{step}/{steps_per_epoch} | loss {avg:.4f}")
            # `out` was already computed above, so no eval()/train() toggle is
            # needed to save it; detach it from the graph before writing.
            utils.save_image(out[:1].detach().float().cpu(), f"{SAMPLES_DIR}/ep{epoch}_st{step}.png")

    torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f"johnson_nold_epoch{epoch}.pth"))
    print(f"[E{epoch}] checkpoint saved.")


Training (no dataloader)…
Epoch 1: ====................  20.0% | loss: 24.14405E1 S200/1000 | loss 26.4994
E1 S1000/1000 | loss 15.2456
[E1] checkpoint saved.
Epoch 2: ====................  20.0% | loss: 18.7341E2 S200/1000 | loss 14.3131
E2 S1000/1000 | loss 11.1061
[E2] checkpoint saved.
Epoch 3: ====................  20.0% | loss: 10.1141E3 S200/1000 | loss 11.5658
E3 S1000/1000 | loss 10.4699
[E3] checkpoint saved.
Epoch 4: ====................  20.0% | loss: 5.40036E4 S200/1000 | loss 11.5797
E4 S1000/1000 | loss 10.4822
[E4] checkpoint saved.
Epoch 5: ====................  20.0% | loss: 27.7538E5 S200/1000 | loss 10.9001
E5 S1000/1000 | loss 9.4419
[E5] checkpoint saved.
Epoch 6: ====................  20.0% | loss: 11.0864E6 S200/1000 | loss 11.3464
E6 S1000/1000 | loss 10.6915
[E6] checkpoint saved.


#### Testing for 1 image

In [None]:
# Stylise one test image with the current network, if the file is available.
model.eval()
test_img_path = "test_imgs/butterfly.jpg"
if not os.path.exists(test_img_path):
    print("Test image not found at", test_img_path)
else:
    stylized_path = os.path.join(SAMPLES_DIR, "stylized_image.png")
    # Same preprocessing as in training, plus a leading batch dimension of 1.
    x = train_tf(exif_fix_and_open(test_img_path)).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        y = model(x).cpu()
    utils.save_image(y, stylized_path)
    print("Saved stylized image to", stylized_path)

Saved stylized image to ./samples_nststyle/stylized_image.png
