In [16]:
!pip install flask pyngrok torch torchvision torchaudio segmentation-models-pytorch efficientnet-pytorch pillow matplotlib flask-cors



In [19]:
import base64
import io
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from flask import Flask, request, jsonify
from flask_cors import CORS
from matplotlib import pyplot as plt
from PIL import Image
from pyngrok import ngrok
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50

In [None]:
# Run inference on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# Side length (pixels) of the square input the segmentation preprocessing resizes to.
DEEP_LAB_IMG_SIZE = 512


# The ngrok auth token is a credential and must never be stored in the notebook.
# Provide it through the NGROK_AUTHTOKEN environment variable (or the platform's
# secret store) before running this cell. Any token that was previously committed
# here should be revoked and regenerated.
NGROK_TOKEN = os.environ.get("NGROK_AUTHTOKEN", "")

# Best-effort cleanup of a tunnel left over from an earlier run of this cell;
# a failure here (e.g. no agent running) is deliberately ignored.
try:
    ngrok.kill()
except Exception:
    pass

# Stays None when no tunnel could be opened, so the name is always defined.
public_url = None
try:
    if NGROK_TOKEN:
        ngrok.set_auth_token(NGROK_TOKEN)
    else:
        print("NGROK_AUTHTOKEN is not set; relying on any token already present in the ngrok config.")
    public_url = ngrok.connect(5000)
    print(" * Public URL:", public_url)
except Exception as e:
    # The server is still usable on localhost without a public tunnel.
    print("Ngrok error (continuing local):", e)


app = Flask(__name__)
CORS(app)


def array_to_base64_png(arr):
    """Encode an image array as a base64 PNG string.

    Args:
        arr: HxW (grayscale) or HxWx3 uint8 numpy array.

    Returns:
        str: the PNG bytes, base64-encoded and decoded as UTF-8 text.
    """
    buf = io.BytesIO()
    if arr.ndim == 2:
        # Pin the grey scale to the full uint8 range. Without vmin/vmax, imsave
        # rescales to the array's own min/max, so a constant array (for example
        # an all-255 mask) is rendered black and intermediate greys are stretched.
        plt.imsave(buf, arr, cmap="gray", vmin=0, vmax=255, format="png")
    else:
        Image.fromarray(arr).save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")

# ------------------------------
# 1) SEGMENTATION: DeepLabV3 (ResNet50 backbone)
# ------------------------------
def load_deeplabv3_resnet50(weights_path: str, device: torch.device):
    """Build a single-class DeepLabV3-ResNet50 and load weights from a checkpoint.

    Args:
        weights_path: Path to a file readable by ``torch.load``. It may contain
            either a bare state dict or a dict holding the state dict under the
            ``"model_state_dict"`` key.
        device: Device the model is moved to after the weights are loaded.

    Returns:
        The model, on ``device`` and in eval mode.
    """
    model = deeplabv3_resnet50(pretrained=False, num_classes=1, aux_loss=None)
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(weights_path, map_location="cpu")

    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    else:
        state = ckpt

    # Strip only a *leading* "module." (as present when the model was saved while
    # wrapped in nn.DataParallel). str.replace would also delete any later
    # occurrence of "module." inside a key and corrupt it.
    prefix = "module."
    new_state = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
    }

    # strict=False tolerates key mismatches, which would otherwise go unnoticed
    # and leave layers at their initial values, so report what did not line up.
    result = model.load_state_dict(new_state, strict=False)
    if result.missing_keys:
        print(f"Warning: {len(result.missing_keys)} model keys not found in checkpoint "
              f"(first few: {result.missing_keys[:5]})")
    if result.unexpected_keys:
        print(f"Warning: {len(result.unexpected_keys)} checkpoint keys not used by the model "
              f"(first few: {result.unexpected_keys[:5]})")

    model.to(device)
    model.eval()
    return model

# Location of the trained segmentation weights.
DEEP_LAB_CHECKPOINT = "/kaggle/input/new/other/default/1/best_deeplabv3_defacto_last.pth"  # <- change if needed

print("Loading DeepLab model...")
try:
    seg_model = load_deeplabv3_resnet50(DEEP_LAB_CHECKPOINT, DEVICE)
except Exception as err:
    # Report the cause, then re-raise: the server cannot work without this model.
    print("Failed to load DeepLab checkpoint:", err)
    raise
else:
    print("DeepLab model loaded.")

# Segmentation preprocessing: PIL image -> square, normalised float tensor.
seg_preprocess = transforms.Compose(
    [
        transforms.Resize((DEEP_LAB_IMG_SIZE,) * 2),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

def deeplab_predict_and_overlay(pil_img: Image.Image, threshold: float = 0.5,
                                alpha: float = 0.4, color=(0, 255, 0)):
    """Segment an image and build mask visualisations at the original resolution.

    Args:
        pil_img: Input image. Converted to RGB if it is in any other mode.
        threshold: Sigmoid probability above which a pixel counts as foreground.
        alpha: Blend weight of ``color`` over masked pixels in the overlay.
        color: RGB triple (0-255) used to tint masked pixels in the overlay.

    Returns:
        tuple: ``(mask_resized, overlay, masked_img)`` where
            * ``mask_resized`` is an HxW uint8 array of 0/1 at the input size,
            * ``overlay`` is the HxWx3 uint8 image with masked pixels tinted,
            * ``masked_img`` is the HxWx3 uint8 image with unmasked pixels zeroed.
    """
    # Both the 3-channel normalisation and the colour blend below require RGB.
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    orig_w, orig_h = pil_img.size
    input_tensor = seg_preprocess(pil_img).unsqueeze(0).to(DEVICE)  # add batch dim

    with torch.no_grad():
        out = seg_model(input_tensor)
        # Some implementations return a dict with the prediction under 'out'.
        logits = out["out"] if isinstance(out, dict) and "out" in out else out

        # Expected to be 1 x 1 x H x W since the model is built with num_classes=1;
        # the two squeezes drop the batch and channel dimensions.
        probs = torch.sigmoid(logits)
        mask = (probs > threshold).squeeze(0).squeeze(0).cpu().numpy().astype(np.uint8)

    # Scale the mask back to the original size; nearest-neighbour keeps it binary.
    mask_img = Image.fromarray(mask * 255).resize((orig_w, orig_h), resample=Image.NEAREST)
    mask_resized = (np.array(mask_img) > 127).astype(np.uint8)
    mask_bool = mask_resized.astype(bool)

    img_np = np.array(pil_img).astype(np.uint8)

    # Alpha-blend the tint colour over the masked pixels only.
    overlay = img_np.copy()
    tint = np.array(color, dtype=np.uint8)
    overlay[mask_bool] = (overlay[mask_bool].astype(np.float32) * (1 - alpha) + tint * alpha).astype(np.uint8)

    # Keep only the masked region; everything else becomes black.
    masked_img = img_np.copy()
    masked_img[~mask_bool] = 0

    return mask_resized, overlay, masked_img

@app.route("/segment", methods=["POST"])
def segment_route():
    """Segment the image uploaded in the ``image`` form field.

    Returns:
        JSON with base64 PNG strings under ``mask``, ``overlay`` and
        ``masked_image``; or a JSON ``error`` with status 400 when no file
        was uploaded or the upload cannot be decoded as an image.
    """
    if "image" not in request.files:
        return jsonify({"error": "No image uploaded"}), 400

    img_file = request.files["image"]
    try:
        pil_img = Image.open(img_file).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError):
        # A corrupt, empty or non-image upload is a client error, not a server fault.
        return jsonify({"error": "Uploaded file is not a valid image"}), 400

    mask_resized, overlay, masked_img = deeplab_predict_and_overlay(pil_img, threshold=0.5)

    return jsonify({
        # The mask holds 0/1; scale to 0/255 so it is visible as an image.
        "mask": array_to_base64_png((mask_resized * 255).astype(np.uint8)),
        "overlay": array_to_base64_png(overlay),
        "masked_image": array_to_base64_png(masked_img)
    })


# ------------------------------
# 2) CLASSIFICATION (ResNet50)
# ------------------------------
# Maps the classifier's output index to a human-readable label.
CLASS_NAMES = {
    0: "splicing",
    1: "copymove",
    2: "inpainting",
    3: "face"
}

clf_model = models.resnet50(pretrained=False)
# Replace the default head with one output per entry in CLASS_NAMES.
clf_model.fc = nn.Linear(clf_model.fc.in_features, len(CLASS_NAMES))


CLASSIF_CHECKPOINT = "/kaggle/input/deeplearing/other/default/1/best_resnet50_defacto.pth"

print("Loading classification model...")
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load(CLASSIF_CHECKPOINT, map_location="cpu")

# Accept both a bare state dict and one wrapped under "model_state_dict",
# matching what the segmentation loader accepts.
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
    clf_state = ckpt["model_state_dict"]
else:
    clf_state = ckpt

# Strip only a leading "module." prefix; str.replace would also remove any
# later occurrence inside a key.
clean = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in clf_state.items()
}

# strict=False hides key mismatches, so surface them instead of silently
# serving a partially initialised model.
clf_load_result = clf_model.load_state_dict(clean, strict=False)
if clf_load_result.missing_keys:
    print(f"Warning: {len(clf_load_result.missing_keys)} model keys not found in checkpoint "
          f"(first few: {clf_load_result.missing_keys[:5]})")
if clf_load_result.unexpected_keys:
    print(f"Warning: {len(clf_load_result.unexpected_keys)} checkpoint keys not used by the model "
          f"(first few: {clf_load_result.unexpected_keys[:5]})")

clf_model.to(DEVICE)
clf_model.eval()
print("Classification model loaded.")

# Classifier preprocessing: PIL image -> 256x256 normalised float tensor.
clf_preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485,0.456,0.406),
                         std=(0.229,0.224,0.225))
])

@app.route("/classify", methods=["POST"])
def classify_route():
    """Classify the image uploaded in the ``image`` form field.

    Returns:
        JSON with ``class_id``, ``class_name`` and ``confidence`` (the softmax
        probability of the predicted class); or a JSON ``error`` with status
        400 when no file was uploaded or the upload cannot be decoded as an image.
    """
    if "image" not in request.files:
        return jsonify({"error": "No image uploaded"}), 400

    img_file = request.files["image"]
    try:
        pil_img = Image.open(img_file).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError):
        # A corrupt, empty or non-image upload is a client error, not a server fault.
        return jsonify({"error": "Uploaded file is not a valid image"}), 400

    tensor = clf_preprocess(pil_img).unsqueeze(0).to(DEVICE)  # add batch dim

    with torch.no_grad():
        out = clf_model(tensor)
        probs = torch.softmax(out, dim=1)
        # .item() converts a one-element tensor straight to a Python scalar.
        idx = int(probs.argmax(dim=1).item())
        confidence = float(probs[0, idx].item())

    return jsonify({
        "class_id": idx,
        # Fall back to the numeric id if the index has no registered name.
        "class_name": CLASS_NAMES.get(idx, str(idx)),
        "confidence": confidence
    })

# ------------------------------
# Run server
# ------------------------------
# Start Flask's built-in server when this code runs as the main module.
if __name__ == "__main__":
    app.run(
        port=5000,  # same local port that ngrok.connect() is pointed at
    )

Using device: cuda
 * Public URL: NgrokTunnel: "https://452334c7089b.ngrok-free.app" -> "http://localhost:5000"
Loading DeepLab model...
DeepLab model loaded.
Loading classification model...
Classification model loaded.
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
