In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import os

# Detect whether this notebook runs inside Google Colab; the ``files`` upload
# helper only exists there.
try:
    from google.colab import files
except ImportError:
    # Only a missing package means "not Colab". A bare ``except`` would also
    # swallow KeyboardInterrupt/SystemExit and unrelated failures.
    COLAB = False
else:
    COLAB = True

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", DEVICE)
def load_image(path, max_size=512, shape=None):
    """Load an image file as an ImageNet-normalised ``[1, 3, H, W]`` tensor.

    Args:
        path: Location of the image; anything ``PIL.Image.open`` accepts.
        max_size: Upper bound for the longer side in pixels. Larger images
            are scaled down with the aspect ratio preserved.
        shape: Optional ``(width, height)`` pair. When given, the image is
            resized to exactly this size after the ``max_size`` step.

    Returns:
        A float32 tensor on ``DEVICE`` holding the RGB image after the
        mean/std normalisation below has been applied.
    """
    # The context manager closes the underlying file. ``convert`` returns a
    # fully loaded image, so the pixels stay valid after the handle is closed
    # (this also makes the former ``.copy()`` / ``.load()`` calls redundant).
    with Image.open(path) as source:
        img = source.convert("RGB")

    longest = max(img.size)
    if longest > max_size:
        scale = max_size / longest
        # Keep at least one pixel per side so an extreme aspect ratio cannot
        # truncate a dimension to zero, which would make the resize fail.
        new_w = max(1, int(img.size[0] * scale))
        new_h = max(1, int(img.size[1] * scale))
        img = img.resize((new_w, new_h), Image.LANCZOS)

    if shape is not None:
        img = img.resize((int(shape[0]), int(shape[1])), Image.LANCZOS)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # Add the batch dimension expected by the convolutional layers.
    return transform(img).unsqueeze(0).to(DEVICE)


class Gram(nn.Module):
    """Compute the normalised Gram matrix of a batch of feature maps."""

    def forward(self, x):
        """Return the per-sample Gram matrices of ``x``.

        Args:
            x: Feature tensor of shape ``[B, C, H, W]``.

        Returns:
            Tensor of shape ``[B, C, C]`` containing the channel-by-channel
            inner products, divided by ``C * H * W``.
        """
        b, c, h, w = x.size()
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs
        # such as permuted or channels-last feature maps.
        features = x.reshape(b, c, h * w)                   # [B, C, H*W]
        G = torch.bmm(features, features.transpose(1, 2))   # [B, C, C]
        return G.div(c * h * w)

class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a target."""

    def __init__(self, target):
        super().__init__()
        self.loss = 0.0
        # Detached so the reference features never take part in autograd.
        self.target = target.detach()

    def forward(self, x):
        """Store the content loss for ``x`` in ``self.loss``; return ``x`` as is."""
        reference = self.target
        self.loss = nn.functional.mse_loss(input=x, target=reference)
        return x

class StyleLoss(nn.Module):
    """Pass-through layer that records the MSE between Gram matrices."""

    def __init__(self, target):
        super().__init__()
        self.loss = 0.0
        # The target is stored as a detached Gram matrix of the given
        # feature maps, so it is computed only once.
        target_gram = Gram()(target)
        self.target = target_gram.detach()

    def forward(self, x):
        """Store the style loss for ``x`` in ``self.loss``; return ``x`` as is."""
        current_gram = Gram()(x)
        self.loss = nn.functional.mse_loss(input=current_gram, target=self.target)
        return x

class Normalize(nn.Module):
    """Apply per-channel ``(img - mean) / std`` normalisation.

    The constants are the ImageNet channel statistics commonly used with
    torchvision's pretrained models.
    """

    def __init__(self):
        super().__init__()
        mean = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)
        # Registered as buffers so ``.to()`` / ``.cuda()`` move them together
        # with the module; non-persistent keeps the state_dict unchanged.
        self.register_buffer("mean", mean, persistent=False)
        self.register_buffer("std", std, persistent=False)

    def forward(self, img):
        """Return ``img`` normalised channel-wise; expects ``[B, 3, H, W]``."""
        return (img - self.mean) / self.std

def build_model(cnn, style, content, normalize_input=False):
    """Assemble the loss network used for neural style transfer.

    The layers of ``cnn`` are copied into a new ``nn.Sequential``; a
    ``ContentLoss`` is inserted after ``conv_4`` and a ``StyleLoss`` after
    each of ``conv_1`` .. ``conv_5``. Everything after the last loss layer
    is dropped.

    Args:
        cnn: Feature extractor whose direct children are iterated in order.
        style: Style image batch used to compute the style targets.
        content: Content image batch used to compute the content target.
        normalize_input: When True, prepend a ``Normalize`` layer. Leave it
            False (the default) for tensors produced by ``load_image``,
            which are already mean/std normalised; normalising them a
            second time feeds the network a wrongly scaled input.

    Returns:
        ``(model, style_losses, content_losses)`` where the two lists hold
        the loss modules that were inserted into ``model``.
    """
    content_layers = {"conv_4"}
    style_layers = {"conv_1", "conv_2", "conv_3", "conv_4", "conv_5"}

    if normalize_input:
        model = nn.Sequential(Normalize().to(DEVICE))
    else:
        model = nn.Sequential()
    style_losses = []
    content_losses = []

    i = 0  # index of the most recent convolution, used to name layers
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f"conv_{i}"
        elif isinstance(layer, nn.ReLU):
            name = f"relu_{i}"
            # An in-place ReLU would overwrite the activations that the
            # loss layer inserted just before it has looked at.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f"pool_{i}"
        else:
            name = layer.__class__.__name__ + str(i)

        model.add_module(name, layer)

        if name in content_layers:
            # Targets are constants, so no autograd graph is needed.
            with torch.no_grad():
                target = model(content)
            cl = ContentLoss(target)
            model.add_module(f"content_loss_{i}", cl)
            content_losses.append(cl)

        if name in style_layers:
            with torch.no_grad():
                target = model(style)
            sl = StyleLoss(target)
            model.add_module(f"style_loss_{i}", sl)
            style_losses.append(sl)

    # Layers behind the last loss module never influence the objective.
    last_loss = -1
    for idx, module in enumerate(model):
        if isinstance(module, (ContentLoss, StyleLoss)):
            last_loss = idx
    model = model[:last_loss + 1]
    return model, style_losses, content_losses
def run_nst(content, style, steps=350, style_weight=3e2, content_weight=1.0, lr=0.02):
    """Optimise an image so it matches ``content`` in structure and ``style`` in texture.

    The optimisation variable lives in the same mean/std-normalised space as
    the tensors returned by ``load_image``. After every optimiser step it is
    projected back onto the set of values that correspond to valid pixels
    in ``[0, 1]``.

    Args:
        content: Normalised content image batch, ``[1, 3, H, W]``.
        style: Normalised style image batch with the same spatial size.
        steps: Number of Adam iterations.
        style_weight: Multiplier for the summed style losses.
        content_weight: Multiplier for the summed content losses.
        lr: Adam learning rate.

    Returns:
        A CPU tensor of shape ``[3, H, W]`` with pixel values in ``[0, 1]``.
    """
    cnn = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(DEVICE).eval()
    # Only the image is optimised; freezing the network avoids computing
    # gradients for its weights on every backward pass.
    cnn.requires_grad_(False)
    model, style_losses, content_losses = build_model(cnn, style, content)

    # Built once instead of on every iteration.
    mean = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)
    # Normalised-space values that map to pixel intensities 0 and 1.
    lower = (0.0 - mean) / std
    upper = (1.0 - mean) / std

    input_img = torch.randn_like(content).to(DEVICE).requires_grad_(True)
    opt = optim.Adam([input_img], lr=lr)

    for step in range(1, steps + 1):
        opt.zero_grad()
        # The forward pass is run for its side effect: each loss module
        # stores its current value in ``.loss``.
        model(input_img)

        s_loss = sum(sl.loss for sl in style_losses)
        c_loss = sum(cl.loss for cl in content_losses)
        loss = style_weight * s_loss + content_weight * c_loss

        if torch.isnan(loss):
            print("NaN encountered in loss. Stopping.")
            break

        loss.backward()
        torch.nn.utils.clip_grad_norm_([input_img], 10.0)
        opt.step()

        with torch.no_grad():
            # Equivalent to de-normalising, clamping to [0, 1] and
            # re-normalising. The previous code dropped the mean when
            # re-normalising, which shifted the image on every step.
            projected = torch.minimum(torch.maximum(input_img, lower), upper)
            input_img.copy_(projected)

        if step % 50 == 0 or step == 1:
            print(f"Step {step}/{steps}  Style Loss: {float(s_loss):.4f}  Content Loss: {float(c_loss):.4f}")

    with torch.no_grad():
        out = input_img.detach().cpu().squeeze(0)
        # Undo the normalisation to get displayable pixel values.
        out = out * std.cpu().view(3, 1, 1) + mean.cpu().view(3, 1, 1)
        out = out.clamp(0, 1)
    return out


class TransformerNet(nn.Module):
    """Small convolutional encoder/decoder applied by ``fast_style``.

    Two stride-2 convolutions halve the resolution twice and two transposed
    convolutions bring it back up; the output has three channels.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, stride):
        """Conv -> affine InstanceNorm -> in-place ReLU, padded by ``kernel_size // 2``."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=kernel_size // 2,
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        super().__init__()
        block = self._conv_block
        # The order (and therefore the state_dict keys ``model.<idx>...``)
        # must stay fixed so saved checkpoints keep loading.
        layers = [
            block(3, 32, 9, 1),
            block(32, 64, 3, 2),
            block(64, 128, 3, 2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.model(x)


def fast_style(model_path, input_path, out_path="fast_output.png"):
    """Stylise one image with a saved ``TransformerNet`` and write the result.

    Args:
        model_path: Checkpoint holding either a state dict or a dict with a
            ``'state_dict'`` entry.
        input_path: Image to stylise.
        out_path: Where the stylised image is saved.
    """
    # Close the file handle deterministically; ``convert`` yields a loaded
    # image that stays valid afterwards.
    with Image.open(input_path) as source:
        img = source.convert("RGB")
    x = transforms.ToTensor()(img).unsqueeze(0).to(DEVICE)

    net = TransformerNet().to(DEVICE)
    # NOTE(security): ``torch.load`` unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source (or pass
    # ``weights_only=True`` on PyTorch versions that support it).
    state = torch.load(model_path, map_location=DEVICE)

    # Some checkpoints wrap the weights in a dict under 'state_dict'.
    if isinstance(state, dict) and 'state_dict' in state:
        state = state['state_dict']
    net.load_state_dict(state)
    net.eval()

    with torch.no_grad():
        y = net(x)
        # Clamp to the displayable range before converting back to an image.
        y = y.clamp(0, 1).cpu().squeeze(0)
    out = transforms.ToPILImage()(y)
    out.save(out_path)
    print("Saved", out_path)

# Notebook driver: outside Colab only usage hints are printed; inside Colab
# the user uploads two images and the optimisation-based transfer is run.
if not COLAB:
    print("Not running in Colab. Call functions directly, e.g.:")
    print(
        " content = load_image('content.jpg')\n"
        " style   = load_image('style.jpg', shape=(content.size(3), content.size(2)))\n"
        " out = run_nst(content, style)\n"
        " # save out with the same conversion used above"
    )
else:
    print("Upload content + style images...")
    up = files.upload()
    paths = list(up.keys())
    if len(paths) < 2:
        print("Please upload at least two files: content (first) then style (second).")
    else:
        # The first uploaded file is the content image, the second the style.
        content = load_image(paths[0])

        # Resize the style image to the content's (width, height).
        style = load_image(paths[1], shape=(content.size(3), content.size(2)))
        out = run_nst(content, style, steps=300, style_weight=3e2, content_weight=1.0, lr=0.02)

        # [3, H, W] floats in [0, 1] -> [H, W, 3] bytes for PIL.
        out_img = (out.permute(1, 2, 0).numpy() * 255).astype('uint8')
        Image.fromarray(out_img).save("nst_output.png")
        print("Saved nst_output.png")


Using: cpu
Upload content + style images...


KeyboardInterrupt: 