## Feed Forward Style Transfer

In [1]:
import os
import time
import random
from PIL import Image, ImageOps
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, models, utils

In [2]:
#config same to nst
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Hyperparameters
IMG_SIZE = 128    #reduced to 128 for quick check(can use 256/512 later)       
BATCH_SIZE = 6
NUM_EPOCHS = 3
LR = 1e-3
CONTENT_WEIGHT = 1.0
STYLE_WEIGHT   = 1e6   
TV_WEIGHT      = 1e-6

In [3]:
#fielpaths
CONTENT_ROOT = "../Data/dataset/clean/animals_balanced"   
STYLE_ROOT= "../Data/dataset/clean/origami_images"
SPLIT_ROOT   = "../Data/dataset/split"   
CHECKPOINT_DIR = "./checkpoints_nststyle"
SAMPLES_DIR    = "./samples_nststyle"

for d in [SPLIT_ROOT, CHECKPOINT_DIR, SAMPLES_DIR]:
    os.makedirs(d, exist_ok=True)

TARGET_CLASS = "butterfly" #single class (inital)

#### VGG Layer Configs and Normalization

In [4]:
#Vsame as NST
LAYER_INDICES = {
    'conv1_1': '0', 
    'conv1_2': '2', 
    'conv2_1': '5', 
    'conv2_2': '7',
    'conv3_1': '10', 
    'conv3_2': '12', 
    'conv3_3': '14', 
    'conv3_4': '16',
    'conv4_1': '19', 
    'conv4_2': '21', 
    'conv4_3': '23', 
    'conv4_4': '25',
    'conv5_1': '28', 
    'conv5_2': '30', 
    'conv5_3': '32', 
    'conv5_4': '34'
}

LAYER_CONFIGS = {
    'gatys': {
        'content': ['conv4_2'],
        'style': ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1'],
        'style_weights': {
            'conv1_1': 1.0,
            'conv2_1': 0.8,
            'conv3_1': 0.5,
            'conv4_1': 0.3,
            'conv5_1': 0.1
        },
    }
}
ACTIVE_LAYER_CONFIG = 'gatys'

#normalization
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD  = [0.229, 0.224, 0.225]


#### Normalization for VGG and Feature Extraction

In [5]:
def exif_fix_and_open(path):
    img = Image.open(path)
    img = ImageOps.exif_transpose(img)
    return img.convert("RGB")

#same as nst.py
def normalize_for_vgg(x):
    mean = torch.tensor(IMG_MEAN).view(1,3,1,1).to(DEVICE)
    std  = torch.tensor(IMG_STD).view(1,3,1,1).to(DEVICE)
    return (x - mean) / std

def extract_features_batch(x, layers, model):
    x_vgg = normalize_for_vgg(x)
    cur = x_vgg
    features = {}
    layers_to_extract = {LAYER_INDICES[name]: name for name in layers}
    for idx, layer in model._modules.items():
        cur = layer(cur)
        if idx in layers_to_extract:
            features[layers_to_extract[idx]] = cur
    return features

def gram_matrix_batch(tensor):
    b, c, h, w = tensor.size()
    f = tensor.view(b, c, h*w)
    return torch.bmm(f, f.transpose(1,2))


#### Data Splitting

In [7]:
def ensure_splits(class_name, val_frac=0.1, test_frac=0.05, seed=42):
    random.seed(seed)
    src_content = os.path.join(CONTENT_ROOT, class_name)
    src_style   = os.path.join(STYLE_ROOT, class_name)
    assert os.path.isdir(src_content), f"Missing content folder: {src_content}"
    assert os.path.isdir(src_style), f"Missing style folder: {src_style}"

    def split_list(files):
        n = len(files)
        n_val = int(n * val_frac)
        n_test = int(n * test_frac)
        return files[n_val+n_test:], files[:n_val], files[n_val:n_val+n_test]

    def copy_split(src_folder, dst_folder, files):
        os.makedirs(dst_folder, exist_ok=True)
        for f in files:
            src = os.path.join(src_folder, f)
            dst = os.path.join(dst_folder, f)
            if not os.path.exists(dst):
                Image.open(src).convert("RGB").save(dst, "JPEG", quality=90)

    for domain in ["content", "style"]:
        src = os.path.join(CONTENT_ROOT if domain=="content" else STYLE_ROOT, class_name)
        files = [f for f in os.listdir(src) if f.lower().endswith(('.jpg','.jpeg','.png'))]
        random.shuffle(files)
        train, val, test = split_list(files)
        for split, flist in zip(["train","val","test"], [train,val,test]):
            out_dir = os.path.join(SPLIT_ROOT, split, class_name)
            copy_split(src, out_dir, flist)

    print(f"Created train/val/test splits under {SPLIT_ROOT}/{class_name}")

ensure_splits(TARGET_CLASS)


Created train/val/test splits under ../Data/dataset/split/butterfly


#### Data Sampling

In [8]:
content_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor()
])
style_transform = content_transform


In [9]:
class SingleClassPairedSampler:
        def __init__(self, split_root, class_name):
            self.split_root = split_root
            self.class_name = class_name
            self._build_index()
            self.split = 'train'

        def _build_index(self):
            for split in ['train','val','test']:
                path = os.path.join(self.split_root, split, self.class_name)
                files = sorted([
                    os.path.join(path, f) for f in os.listdir(path)
                    if f.lower().endswith(('.jpg','.jpeg','.png'))
                ])
                setattr(self, f"{split}_files", files)

        def set_split(self, split): self.split = split

        def sample_batch(self, batch_size):
            files = getattr(self, f"{self.split}_files")
            paths_c = random.choices(files, k=batch_size)
            paths_s = random.choices(files, k=batch_size)
            c_imgs = [content_transform(exif_fix_and_open(p)) for p in paths_c]
            s_imgs = [style_transform(exif_fix_and_open(p)) for p in paths_s]
            return torch.stack(c_imgs), torch.stack(s_imgs)

sampler = SingleClassPairedSampler(SPLIT_ROOT, TARGET_CLASS)
print("Sampler ready, number of train images:", len(sampler.train_files))

Sampler ready, number of train images: 1242


#### Load Pre-Trained VGG

In [10]:
#load pretrained VGG (same form nst.py)
vgg = models.vgg19(pretrained=True).features.to(DEVICE).eval()
for p in vgg.parameters():
    p.requires_grad = False



#### Transformer Network

In [11]:
class ConvLayer(nn.Module): #processes the image and extracts features while keeping style consistent
    def __init__(self, in_c, out_c, kernel, stride):
        super().__init__()
        padding = kernel // 2
        self.conv = nn.Conv2d(in_c, out_c, kernel, stride, padding)
        self.inorm = nn.InstanceNorm2d(out_c, affine=True)
    def forward(self, x):
        return F.relu(self.inorm(self.conv(x)))


In [12]:
class ResidualBlock(nn.Module): #learn style modifications
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
    def forward(self, x):
        out = F.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        return out + x

In [13]:
class UpsampleConv(nn.Module): # upsampling the image (making it bigger)
    def __init__(self, in_c, out_c, kernel, upsample=None):
        super().__init__()
        self.upsample = upsample
        padding = kernel // 2
        self.conv = nn.Conv2d(in_c, out_c, kernel, 1, padding)
        self.inorm = nn.InstanceNorm2d(out_c, affine=True)
    def forward(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=self.upsample, mode='nearest')
        return F.relu(self.inorm(self.conv(x)))

In [14]:
class TransformerNet(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.conv1 = ConvLayer(3, 32, 9, 1)
        self.conv2 = ConvLayer(32, 64, 3, 2)
        self.conv3 = ConvLayer(64, 128, 3, 2)
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)
        self.up1 = UpsampleConv(128, 64, 3, upsample=2)
        self.up2 = UpsampleConv(64, 32, 3, upsample=2)
        self.conv_out = nn.Conv2d(32, 3, 9, 1, 4)
        
    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.res4(y)
        y = self.res5(y)
        y = self.up1(y)
        y = self.up2(y)
        y = self.conv_out(y)
        return torch.sigmoid(y)

In [15]:
model = TransformerNet().to(DEVICE)
opt = optim.Adam(model.parameters(), lr=LR)

In [17]:
config = LAYER_CONFIGS[ACTIVE_LAYER_CONFIG]
content_layers = config['content']
style_layers = config['style']
style_weights = config['style_weights']

sampler.set_split('train')

global_step = 0
for epoch in range(1, NUM_EPOCHS+1):
    model.train()
    epoch_loss = 0
    
    
    # steps_per_epoch = max(100, len(sampler.train_files)//BATCH_SIZE)  
    
    #for quick run
    total_train_images = sampler.data['train']['content'].shape[0] if hasattr(sampler, 'data') else len(sampler.train_files)
    min_steps = 20
    max_steps = 80
    estimated = max(1, total_train_images // BATCH_SIZE)
    steps_per_epoch = min(max_steps, max(min_steps, estimated))
    print(f"Using steps_per_epoch = {steps_per_epoch} (total_train_images={total_train_images}, batch_size={BATCH_SIZE})")
    
    
    for step in range(steps_per_epoch):
        content_batch, style_batch = sampler.sample_batch(BATCH_SIZE)
        content_batch, style_batch = content_batch.to(DEVICE), style_batch.to(DEVICE)

        opt.zero_grad()
        output = model(content_batch)

        c_norm, s_norm, o_norm = map(normalize_for_vgg, [content_batch, style_batch, output])

        # content loss
        c_feats = extract_features_batch(c_norm, content_layers, vgg)
        o_feats = extract_features_batch(o_norm, content_layers, vgg)
        c_loss = sum(torch.mean((o_feats[l]-c_feats[l])**2) for l in content_layers)

        # style loss
        s_feats = extract_features_batch(s_norm, style_layers, vgg)
        s_grams = {l: gram_matrix_batch(s_feats[l]) for l in style_layers}
        o_feats_style = extract_features_batch(o_norm, style_layers, vgg)
        s_loss = 0
        for l in style_layers:
            Gs, Go = s_grams[l], gram_matrix_batch(o_feats_style[l])
            w = style_weights.get(l, 1.0)
            s_loss += w * torch.mean((Go - Gs)**2)

        # TV loss
        tv_loss = torch.mean(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:])) + \
                  torch.mean(torch.abs(output[:, :, :-1, :] - output[:, :, 1:, :]))

        total_loss = CONTENT_WEIGHT*c_loss + STYLE_WEIGHT*s_loss + TV_WEIGHT*tv_loss
        total_loss.backward()
        opt.step()

        epoch_loss += total_loss.item()
        global_step += 1

        if step % 50 == 0:
            print(f"[Epoch {epoch} Step {step}] Total {total_loss.item():.4f}")

        if step % 300 == 0:
            model.eval()
            with torch.no_grad():
                sample_out = model(content_batch[:1]).cpu()
                utils.save_image(sample_out, f"{SAMPLES_DIR}/ep{epoch}_step{global_step}.png")
            model.train()

    torch.save(model.state_dict(), f"{CHECKPOINT_DIR}/model_epoch{epoch}.pth")
    print(f"Epoch {epoch} done | Avg loss {epoch_loss/steps_per_epoch:.4f}")


Using steps_per_epoch = 80 (total_train_images=1242, batch_size=6)
[Epoch 1 Step 0] Total 174016131915317248.0000
[Epoch 1 Step 50] Total 257013316893802496.0000
Epoch 1 done | Avg loss 233098887127944384.0000
Using steps_per_epoch = 80 (total_train_images=1242, batch_size=6)
[Epoch 2 Step 0] Total 132619175532167168.0000
[Epoch 2 Step 50] Total 383216532840251392.0000
Epoch 2 done | Avg loss 225436890313418336.0000
Using steps_per_epoch = 80 (total_train_images=1242, batch_size=6)
[Epoch 3 Step 0] Total 95563520552206336.0000
[Epoch 3 Step 50] Total 218656269643284480.0000
Epoch 3 done | Avg loss 196145614612017984.0000
Using steps_per_epoch = 80 (total_train_images=1242, batch_size=6)
[Epoch 4 Step 0] Total 134578762950901760.0000
[Epoch 4 Step 50] Total 141924007430586368.0000
Epoch 4 done | Avg loss 213499081971780800.0000
Using steps_per_epoch = 80 (total_train_images=1242, batch_size=6)
[Epoch 5 Step 0] Total 205968403674955776.0000


KeyboardInterrupt: 

In [18]:
model = TransformerNet().to(DEVICE)
checkpoint_path = f"{CHECKPOINT_DIR}/model_epoch4.pth"
model.load_state_dict(torch.load(checkpoint_path))
model.eval()  # set to evaluation mode


TransformerNet(
  (conv1): ConvLayer(
    (conv): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
    (inorm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  )
  (conv2): ConvLayer(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (inorm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  )
  (conv3): ConvLayer(
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (inorm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  )
  (res1): ResidualBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (in1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (in2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
 

In [20]:
#test fior 1 image
from PIL import Image
import torch
from torchvision import transforms, utils

img = Image.open("test_imgs/butterfly.jpg").convert('RGB')

content_tensor = content_transform(img).unsqueeze(0).to(DEVICE)  

model.eval()
with torch.no_grad():
    output = model(content_tensor)
utils.save_image(output.cpu(), "stylized_image.png")
