# Virtual Try-On Inference Orchestrator

This notebook orchestrates the full virtual try-on pipeline, from pre-processing (SCHP, DensePose, OpenPose) to flow estimation and final garment warping using FVNT.

In [None]:
import os, sys, torch, shutil, numpy as np, math
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
from torchvision.ops import deform_conv2d

# Make the project root and its src/ folder importable from this notebook.
PROJECT_ROOT = "/content/drive/MyDrive/virtual_tryon_project"
# Each missing entry is pushed to the front, so when both are added src/ ends
# up ahead of the project root on sys.path.
for _import_dir in (PROJECT_ROOT, f"{PROJECT_ROOT}/src"):
    if _import_dir not in sys.path:
        sys.path.insert(0, _import_dir)

from agnostic_logic import get_agnostic_person

print("✅ Environment ready")

## 1. FVNT Setup & Injection
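Injects a build-free DCN shim (based on `torchvision.ops.deform_conv2d`) into the FVNT checkout and imports the FVNT generator. The FEM weights are loaded in the next section.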

In [None]:
# Locate the FVNT checkout and make sure its `Deformable` package folder exists.
FVNT_DIR = f"{PROJECT_ROOT}/FVNT"
DCN_FOLDER = os.path.join(FVNT_DIR, "Deformable")
os.makedirs(DCN_FOLDER, exist_ok=True)

# (Re)write the package initializer so `Deformable` re-exports DeformConvPack.
_dcn_init_path = os.path.join(DCN_FOLDER, "__init__.py")
with open(_dcn_init_path, "w") as f:
    f.write("from .modules import DeformConvPack")

# Source text for Deformable/modules.py: a DeformConvPack implemented on top of
# torchvision.ops.deform_conv2d (presumably standing in for FVNT's compiled
# deformable-conv extension -- confirm against the FVNT repo).
#
# The offset-predicting conv must use the same kernel size, stride, padding AND
# dilation as the deformable conv itself; without the dilation its output grid
# has a different spatial size than the one deform_conv2d expects whenever
# dilation != 1. Passing dilation does not change any parameter shape, so
# existing checkpoints still load.
#
# `im2col_step` and `lr_mult` are accepted for signature compatibility but are
# not used by this implementation.
modules_py = """
import torch
from torch import nn
from torchvision.ops import deform_conv2d
import math

class DeformConvPack(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 dilation=1, groups=1, deformable_groups=1, im2col_step=64, bias=True, lr_mult=0.1):
        super(DeformConvPack, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride = (stride, stride) if isinstance(stride, int) else stride
        self.padding = (padding, padding) if isinstance(padding, int) else padding
        self.dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        if bias: self.bias = nn.Parameter(torch.Tensor(out_channels))
        else: self.register_parameter('bias', None)
            
        self.conv_offset = nn.Conv2d(in_channels, 
                                     deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
                                     kernel_size=self.kernel_size, 
                                     stride=self.stride, 
                                     padding=self.padding,
                                     dilation=self.dilation)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size: n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None: self.bias.data.uniform_(-stdv, stdv)
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv2d(x, offset, self.weight, self.bias, 
                             stride=self.stride, padding=self.padding, dilation=self.dilation)
"""
# Persist the shim source as Deformable/modules.py, overwriting any earlier copy.
_dcn_modules_path = os.path.join(DCN_FOLDER, "modules.py")
with open(_dcn_modules_path, "w") as f:
    f.write(modules_py)

# The FVNT checkout has to be on sys.path before its generator can be imported.
if FVNT_DIR not in sys.path:
    sys.path.insert(0, FVNT_DIR)
from mine.network_stage_2_mine_x2_resflow import Stage_2_generator

print("✅ DCN Injected & Generator Imported")

## 2. Load Model & Helpers

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): 20 presumably matches the 20-channel inputs built in
# prep_tensor / run_pipeline -- confirm against Stage_2_generator's signature.
fem = Stage_2_generator(20).to(device)
ckpt_path = f"{PROJECT_ROOT}/checkpoints/stage2_model.pth"
if os.path.exists(ckpt_path):
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    ckpt = torch.load(ckpt_path, map_location=device)
    # The weights may be stored under the 'G' key or as a bare state dict.
    fem.load_state_dict(ckpt['G'] if 'G' in ckpt else ckpt)
    print("✅ FEM Loaded")
else:
    # Keep going (best effort) but make the missing checkpoint visible instead
    # of silently running inference with untrained weights.
    print(f"⚠️ FEM checkpoint not found at {ckpt_path}; the model keeps its initial weights")
# Always switch to inference mode, whether or not a checkpoint was loaded.
fem.eval()

def warp_high_res(img_t, low_res_flow):
    """Warp a high-resolution image with a flow field given at lower resolution.

    The flow is bilinearly upsampled to the image size, its x/y displacements
    are rescaled by the width/height ratio between the two resolutions, and
    the image is resampled with ``grid_sample`` at
    ``pixel position + displacement``.

    Args:
        img_t: Image tensor of shape (B, C, H, W).
        low_res_flow: Flow tensor of shape (B, 2, h, w). Channel 0 holds the
            x displacement and channel 1 the y displacement, in pixels of the
            low-resolution grid.

    Returns:
        The warped image, same shape as ``img_t``. Sample positions outside
        the image use ``grid_sample``'s default padding.
    """
    B, _, H_hr, W_hr = img_t.shape
    _, _, H_lr, W_lr = low_res_flow.shape

    flow_hr = torch.nn.functional.interpolate(
        low_res_flow, size=(H_hr, W_hr), mode='bilinear', align_corners=True
    )
    # Displacements are measured in low-res pixels; convert them to high-res
    # pixels (x scales with width, y with height).
    px_scale = flow_hr.new_tensor([W_hr / W_lr, H_hr / H_lr]).view(1, 2, 1, 1)
    flow_hr = flow_hr * px_scale

    # Build the base pixel grid on the image's own device rather than on a
    # notebook-level global, so the function works for any input placement.
    dev = img_t.device
    xs = torch.arange(W_hr, device=dev).float()
    ys = torch.arange(H_hr, device=dev).float()
    gx = xs.view(1, 1, 1, W_hr).expand(B, 1, H_hr, W_hr)
    gy = ys.view(1, 1, H_hr, 1).expand(B, 1, H_hr, W_hr)
    grid = torch.cat([gx, gy], 1) + flow_hr

    # Map pixel coordinates to grid_sample's [-1, 1] range; max(..., 1) guards
    # against division by zero for a 1-pixel-wide or 1-pixel-tall image.
    extent = grid.new_tensor([max(W_hr - 1, 1), max(H_hr - 1, 1)]).view(1, 2, 1, 1)
    grid = 2.0 * grid / extent - 1.0
    return torch.nn.functional.grid_sample(img_t, grid.permute(0, 2, 3, 1), align_corners=True)

def prep_tensor(img, is_parsing=False):
    """Resize an image to the 192x256 model input and return a batched tensor.

    Args:
        img: Image object exposing ``resize(size, resample)`` (a PIL image).
        is_parsing: When True, ``img`` is treated as a label map and turned
            into a 20-channel mask stack; otherwise it is treated as a colour
            image with channels last.

    Returns:
        If ``is_parsing``: float tensor of shape (1, 20, 256, 192) where
        channels 4, 5, 6 and 7 are binary masks of the pixels carrying that
        label and every other channel is zero.
        Otherwise: float tensor of shape (1, C, 256, 192) holding the pixel
        values divided by 127.5 minus 1 (0..255 maps to -1..1).
    """
    model_w, model_h = 192, 256
    target_size = (model_w, model_h)

    if is_parsing:
        # Nearest-neighbour resampling keeps label ids intact.
        labels = np.array(img.resize(target_size, Image.NEAREST))
        masks = torch.zeros(20, model_h, model_w)
        for cls_id in (4, 5, 6, 7):
            masks[cls_id] = torch.from_numpy((labels == cls_id).astype(np.float32))
        return masks.unsqueeze(0)

    pixels = np.array(img.resize(target_size, Image.BILINEAR))
    chw = torch.from_numpy(pixels).permute(2, 0, 1).float()
    return chw.unsqueeze(0) / 127.5 - 1

## 3. Pipeline Execution
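Loads the precomputed pre-processor outputs, generates the agnostic person representation, estimates the flow with the FEM, and warps the garment at high resolution.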

In [None]:
def run_pipeline(person_id, cloth_id):
    """Warp one garment onto one person's pose and display the result.

    Steps: resolve the dataset paths, build the agnostic parsing via
    ``get_agnostic_person``, run the FEM on the agnostic parsing and the cloth
    mask to obtain a low-resolution flow, then warp the 768x1024 cloth image
    with that flow and show it with matplotlib.

    Args:
        person_id: File name of the person image inside ``test/image``
            (e.g. ``"xxx.jpg"``); the parsing map is looked up under the same
            stem with a ``.png`` extension in ``test/image-parse-v3``.
        cloth_id: File name of the garment inside ``test/cloth``; the mask
            with the same name is read from ``test/cloth-mask``.

    Returns:
        ``numpy`` array of shape (1024, 768, 3) with the warped cloth, values
        clipped to [0, 1].
    """
    # 1. Paths
    data_root = f"{PROJECT_ROOT}/data/viton_hd/test"
    person_img_path = f"{data_root}/image/{person_id}"
    cloth_img_path = f"{data_root}/cloth/{cloth_id}"

    # 2. Pre-processor outputs are read from disk (simulated here; in practice
    #    produce them with !python scripts). Only the final extension is
    #    swapped, so names containing '.jpg' elsewhere stay intact.
    parse_name = os.path.splitext(person_id)[0] + ".png"
    parse_path = f"{data_root}/image-parse-v3/{parse_name}"

    # 3. Generate Agnostic (only the agnostic parsing is needed below)
    _, agnostic_parse = get_agnostic_person(person_img_path, parse_path)

    # 4. FEM Inference
    input_1 = prep_tensor(agnostic_parse, is_parsing=True).to(device)
    with Image.open(f"{data_root}/cloth-mask/{cloth_id}") as mask_file:
        # Force a single 8-bit channel so the thresholded array is (256, 192)
        # even if the mask was saved as RGB or 1-bit.
        cloth_mask = np.array(mask_file.convert("L").resize((192, 256)))
    input_2 = torch.zeros(1, 20, 256, 192, device=device)
    # The binarised cloth mask goes into channel 5; all other channels stay 0.
    input_2[0, 5] = torch.from_numpy((cloth_mask >= 128).astype(np.float32))

    with torch.no_grad():
        flow_list, _ = fem(input_1, input_2)
    low_res_flow = flow_list[-1]

    # 5. High-Res Warping
    with Image.open(cloth_img_path) as cloth_file:
        # RGB guarantees the channels-last (H, W, 3) layout that permute expects.
        cloth_hd = np.array(cloth_file.convert("RGB").resize((768, 1024)))
    cloth_hd_t = torch.from_numpy(cloth_hd).permute(2, 0, 1).float().unsqueeze(0) / 127.5 - 1
    warped_hd = warp_high_res(cloth_hd_t.to(device), low_res_flow)

    # Convert for display: [-1, 1] CHW tensor -> [0, 1] HWC array
    res = ((warped_hd[0].permute(1, 2, 0).cpu().numpy() + 1) * 0.5).clip(0, 1)
    plt.imshow(res)
    plt.axis('off')
    plt.show()

    return res